I want to examine regression as a surrogate modeling technique, and through this, explore multivariable regression visualization.
Regression has been under-explored as a surrogate modeling technique in visualization, which makes it an interesting avenue for interpretability and explainability of complicated or otherwise inaccessible models. This gap may be due to a general lack of visualization techniques for regression models beyond two or three dimensions, so I think it would also be interesting to explore additional regression visualization techniques as a part of this project.
Specifically, for the surrogate modeling, I’m using MARS (Multivariate Adaptive Regression Splines) to get better fits than plain linear regression while maintaining much of the interpretability of linear models.
For the visualization component, I’m exploring visualizations of multiple linear regressions using different layers of data encodings, and I’m also exploring nomograms as a visualization/interpretability technique for linear regressions.
I’m combining these two ideas by incorporating MARS into nomograms, presented alongside their formulations.
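Since everything below builds on MARS, a quick sketch of its basic building block may help. A MARS model is a weighted sum of hinge functions \(h(x-c) = \max(0, x-c)\), mirrored around knots; the knot and coefficients below are toy values, not anything fit to the data:

```r
# MARS builds piecewise-linear fits from hinge functions mirrored at a knot c
hinge <- function(x, c) pmax(0, x - c)

# Toy prediction with a single knot at 5: the slope can change at the knot
# while the fitted line stays continuous
x <- c(0, 2.5, 5, 7.5, 10)
y_hat <- 2 + 1.5 * hinge(x, 5) + 0.8 * hinge(5, x)
y_hat  # → 6.00 4.00 2.00 5.75 9.50
```

Below the knot only the h(5 − x) term is active; above it only h(x − 5) is, which is what lets MARS bend a regression line without losing the linear-model readability.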
boston <- read_csv("boston.csv")
## Rows: 506 Columns: 14
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## dbl (14): CRIM, ZN, INDUS, CHAS, NOX, RM, AGE, DIS, RAD, TAX, PTRATIO, B, LS...
##
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
boston.adj <- boston %>%
mutate(id = 1:nrow(boston))
summary(boston.adj)
## CRIM ZN INDUS CHAS
## Min. : 0.00632 Min. : 0.00 Min. : 0.46 Min. :0.00000
## 1st Qu.: 0.08205 1st Qu.: 0.00 1st Qu.: 5.19 1st Qu.:0.00000
## Median : 0.25651 Median : 0.00 Median : 9.69 Median :0.00000
## Mean : 3.61352 Mean : 11.36 Mean :11.14 Mean :0.06917
## 3rd Qu.: 3.67708 3rd Qu.: 12.50 3rd Qu.:18.10 3rd Qu.:0.00000
## Max. :88.97620 Max. :100.00 Max. :27.74 Max. :1.00000
## NOX RM AGE DIS
## Min. :0.3850 Min. :3.561 Min. : 2.90 Min. : 1.130
## 1st Qu.:0.4490 1st Qu.:5.886 1st Qu.: 45.02 1st Qu.: 2.100
## Median :0.5380 Median :6.208 Median : 77.50 Median : 3.207
## Mean :0.5547 Mean :6.285 Mean : 68.57 Mean : 3.795
## 3rd Qu.:0.6240 3rd Qu.:6.623 3rd Qu.: 94.08 3rd Qu.: 5.188
## Max. :0.8710 Max. :8.780 Max. :100.00 Max. :12.127
## RAD TAX PTRATIO B
## Min. : 1.000 Min. :187.0 Min. :12.60 Min. : 0.32
## 1st Qu.: 4.000 1st Qu.:279.0 1st Qu.:17.40 1st Qu.:375.38
## Median : 5.000 Median :330.0 Median :19.05 Median :391.44
## Mean : 9.549 Mean :408.2 Mean :18.46 Mean :356.67
## 3rd Qu.:24.000 3rd Qu.:666.0 3rd Qu.:20.20 3rd Qu.:396.23
## Max. :24.000 Max. :711.0 Max. :22.00 Max. :396.90
## LSTAT MEDV id
## Min. : 1.73 Min. : 5.00 Min. : 1.0
## 1st Qu.: 6.95 1st Qu.:17.02 1st Qu.:127.2
## Median :11.36 Median :21.20 Median :253.5
## Mean :12.65 Mean :22.53 Mean :253.5
## 3rd Qu.:16.95 3rd Qu.:25.00 3rd Qu.:379.8
## Max. :37.97 Max. :50.00 Max. :506.0
pairs(boston.adj)
correlationMatrix <- cor(boston)
correlationMatrix
## CRIM ZN INDUS CHAS NOX
## CRIM 1.00000000 -0.20046922 0.40658341 -0.055891582 0.42097171
## ZN -0.20046922 1.00000000 -0.53382819 -0.042696719 -0.51660371
## INDUS 0.40658341 -0.53382819 1.00000000 0.062938027 0.76365145
## CHAS -0.05589158 -0.04269672 0.06293803 1.000000000 0.09120281
## NOX 0.42097171 -0.51660371 0.76365145 0.091202807 1.00000000
## RM -0.21924670 0.31199059 -0.39167585 0.091251225 -0.30218819
## AGE 0.35273425 -0.56953734 0.64477851 0.086517774 0.73147010
## DIS -0.37967009 0.66440822 -0.70802699 -0.099175780 -0.76923011
## RAD 0.62550515 -0.31194783 0.59512927 -0.007368241 0.61144056
## TAX 0.58276431 -0.31456332 0.72076018 -0.035586518 0.66802320
## PTRATIO 0.28994558 -0.39167855 0.38324756 -0.121515174 0.18893268
## B -0.38506394 0.17552032 -0.35697654 0.048788485 -0.38005064
## LSTAT 0.45562148 -0.41299457 0.60379972 -0.053929298 0.59087892
## MEDV -0.38830461 0.36044534 -0.48372516 0.175260177 -0.42732077
## RM AGE DIS RAD TAX PTRATIO
## CRIM -0.21924670 0.35273425 -0.37967009 0.625505145 0.58276431 0.2899456
## ZN 0.31199059 -0.56953734 0.66440822 -0.311947826 -0.31456332 -0.3916785
## INDUS -0.39167585 0.64477851 -0.70802699 0.595129275 0.72076018 0.3832476
## CHAS 0.09125123 0.08651777 -0.09917578 -0.007368241 -0.03558652 -0.1215152
## NOX -0.30218819 0.73147010 -0.76923011 0.611440563 0.66802320 0.1889327
## RM 1.00000000 -0.24026493 0.20524621 -0.209846668 -0.29204783 -0.3555015
## AGE -0.24026493 1.00000000 -0.74788054 0.456022452 0.50645559 0.2615150
## DIS 0.20524621 -0.74788054 1.00000000 -0.494587930 -0.53443158 -0.2324705
## RAD -0.20984667 0.45602245 -0.49458793 1.000000000 0.91022819 0.4647412
## TAX -0.29204783 0.50645559 -0.53443158 0.910228189 1.00000000 0.4608530
## PTRATIO -0.35550149 0.26151501 -0.23247054 0.464741179 0.46085304 1.0000000
## B 0.12806864 -0.27353398 0.29151167 -0.444412816 -0.44180801 -0.1773833
## LSTAT -0.61380827 0.60233853 -0.49699583 0.488676335 0.54399341 0.3740443
## MEDV 0.69535995 -0.37695457 0.24992873 -0.381626231 -0.46853593 -0.5077867
## B LSTAT MEDV
## CRIM -0.38506394 0.4556215 -0.3883046
## ZN 0.17552032 -0.4129946 0.3604453
## INDUS -0.35697654 0.6037997 -0.4837252
## CHAS 0.04878848 -0.0539293 0.1752602
## NOX -0.38005064 0.5908789 -0.4273208
## RM 0.12806864 -0.6138083 0.6953599
## AGE -0.27353398 0.6023385 -0.3769546
## DIS 0.29151167 -0.4969958 0.2499287
## RAD -0.44441282 0.4886763 -0.3816262
## TAX -0.44180801 0.5439934 -0.4685359
## PTRATIO -0.17738330 0.3740443 -0.5077867
## B 1.00000000 -0.3660869 0.3334608
## LSTAT -0.36608690 1.0000000 -0.7376627
## MEDV 0.33346082 -0.7376627 1.0000000
This data is pretty clean for these purposes, and relatively small, which makes it easy to work with and keeps the models fast to run. Because of their strong correlations with other predictors, I’m going to drop INDUS, AGE, and TAX from the linear models and the comparison models. Not all of the variables are distributed fantastically, but it’s generally not too terrible.
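The screening step can be sketched in base R by scanning the correlation matrix for predictor pairs above a cutoff; the toy data frame and the 0.7 cutoff here are illustrative stand-ins (the actual cuts above were a judgment call from the full matrix):

```r
# Flag predictor pairs whose absolute correlation exceeds a cutoff
toy <- data.frame(a = 1:10,
                  b = 2 * (1:10) + 1,     # perfectly correlated with a
                  c = rep(c(1, -1), 5))   # roughly uncorrelated with a and b
cm <- cor(toy)
high <- which(abs(cm) > 0.7 & upper.tri(cm), arr.ind = TRUE)
high  # each row gives the (row, col) indices of a flagged pair
```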
set.seed(497562)
train <- boston.adj %>% sample_frac(0.7)
test <- anti_join(boston.adj, train, by = "id")
summary(train$MEDV)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 5.00 17.50 21.40 22.75 25.15 50.00
summary(test$MEDV)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 5.00 15.78 20.35 22.03 24.57 50.00
RFModel <- randomForest(MEDV~.-id, data=train, importance=T, proximity=T)
print(RFModel)
##
## Call:
## randomForest(formula = MEDV ~ . - id, data = train, importance = T, proximity = T)
## Type of random forest: regression
## Number of trees: 500
## No. of variables tried at each split: 4
##
## Mean of squared residuals: 11.64188
## % Var explained: 85.68
RFModel.simple <- randomForest(MEDV~LSTAT, data=train, importance=T, proximity=T)
print(RFModel.simple)
##
## Call:
## randomForest(formula = MEDV ~ LSTAT, data = train, importance = T, proximity = T)
## Type of random forest: regression
## Number of trees: 500
## No. of variables tried at each split: 1
##
## Mean of squared residuals: 38.00409
## % Var explained: 53.25
The fit is decent considering there was no clean-up, feature selection, etc. As expected, the simpler model is much worse, but it still explains over 50% of the variance, so it’s doing something.
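As a side note, the “% Var explained” figure that randomForest prints is, as I understand it, essentially \(1 - \mathrm{MSE}/\mathrm{Var}(y)\) computed from the out-of-bag predictions. The helper below is my own naming, and randomForest may differ slightly in its variance denominator:

```r
# Percent of variance explained, given a response vector and a model's MSE
pct_var_explained <- function(y, mse) {
  100 * (1 - mse / mean((y - mean(y))^2))
}

pct_var_explained(c(1, 2, 3, 4), 0)     # → 100 (a perfect model)
pct_var_explained(c(1, 2, 3, 4), 1.25)  # → 0 (no better than predicting the mean)
```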
LM.simple <- lm(MEDV ~ LSTAT, data = train)
summary(LM.simple)
##
## Call:
## lm(formula = MEDV ~ LSTAT, data = train)
##
## Residuals:
## Min 1Q Median 3Q Max
## -15.076 -3.872 -1.270 1.596 24.572
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 34.36559 0.67710 50.75 <2e-16 ***
## LSTAT -0.93781 0.04766 -19.68 <2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 6.24 on 352 degrees of freedom
## Multiple R-squared: 0.5238, Adjusted R-squared: 0.5224
## F-statistic: 387.2 on 1 and 352 DF, p-value: < 2.2e-16
ggplot(data = train, aes(x = LSTAT, y = MEDV)) + geom_point() + geom_smooth(method = "lm", se = F) + theme_minimal()
## `geom_smooth()` using formula 'y ~ x'
LM.simpleLog <- lm(MEDV ~ log(LSTAT), data = train)
summary(LM.simpleLog)
##
## Call:
## lm(formula = MEDV ~ log(LSTAT), data = train)
##
## Residuals:
## Min 1Q Median 3Q Max
## -14.4454 -3.5303 -0.5937 2.1272 26.0048
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 51.8645 1.1825 43.86 <2e-16 ***
## log(LSTAT) -12.3619 0.4872 -25.37 <2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 5.376 on 352 degrees of freedom
## Multiple R-squared: 0.6465, Adjusted R-squared: 0.6455
## F-statistic: 643.8 on 1 and 352 DF, p-value: < 2.2e-16
ggplot(data = train, aes(x = log(LSTAT), y = MEDV)) + geom_point() + geom_smooth(method = "lm", se = F) + theme_minimal()
## `geom_smooth()` using formula 'y ~ x'
LM.full <- lm(MEDV ~ .-id-INDUS-AGE-TAX, data = train)
summary(LM.full)
##
## Call:
## lm(formula = MEDV ~ . - id - INDUS - AGE - TAX, data = train)
##
## Residuals:
## Min 1Q Median 3Q Max
## -11.3762 -2.9419 -0.3283 1.7906 25.1764
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 29.470325 6.271687 4.699 3.79e-06 ***
## CRIM -0.147547 0.039537 -3.732 0.000222 ***
## ZN 0.046285 0.015647 2.958 0.003310 **
## CHAS 3.190066 1.014763 3.144 0.001814 **
## NOX -16.413657 4.213903 -3.895 0.000118 ***
## RM 4.063011 0.498869 8.144 7.16e-15 ***
## DIS -1.477538 0.221431 -6.673 1.01e-10 ***
## RAD 0.176183 0.048273 3.650 0.000303 ***
## PTRATIO -0.989909 0.150357 -6.584 1.72e-10 ***
## B 0.015271 0.003241 4.712 3.57e-06 ***
## LSTAT -0.539329 0.057368 -9.401 < 2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 4.702 on 343 degrees of freedom
## Multiple R-squared: 0.7364, Adjusted R-squared: 0.7288
## F-statistic: 95.84 on 10 and 343 DF, p-value: < 2.2e-16
LM.simple2var <- lm(MEDV ~ LSTAT + RM, data = train)
summary(LM.simple2var)
##
## Call:
## lm(formula = MEDV ~ LSTAT + RM, data = train)
##
## Residuals:
## Min 1Q Median 3Q Max
## -12.8420 -3.4141 -0.9693 1.8219 28.3297
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) -2.25344 3.88360 -0.580 0.562
## LSTAT -0.63658 0.05296 -12.020 <2e-16 ***
## RM 5.23119 0.54803 9.545 <2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 5.567 on 351 degrees of freedom
## Multiple R-squared: 0.6219, Adjusted R-squared: 0.6198
## F-statistic: 288.7 on 2 and 351 DF, p-value: < 2.2e-16
ggplot(data = train, aes(x = LSTAT, y = MEDV)) + geom_point() + geom_smooth(method = "lm", se = F) + theme_minimal()
## `geom_smooth()` using formula 'y ~ x'
From some initial setup, I was able to recreate the visualization I saw that used encodings beyond x and y for additional variables. This plot has three independent variables: one on the x-axis, one on the color scale, and one on the faceting.
ggplot(data = train, aes(x = LSTAT, y = MEDV)) +
geom_point(aes(color = RM)) +
geom_smooth(method = "lm", se = F) +
facet_wrap(vars(CHAS)) + theme_minimal()
## `geom_smooth()` using formula 'y ~ x'
From this implementation, I was interested in trying to cram as many variables as I could into a single plot through various encodings. The ggplot2 package in R is relatively strict about following the Grammar of Graphics framework: while you can have lots of different scales, you can only have one of each unless you really hack it (meaning you can’t easily make those kind-of terrible two-y-axis plots people often make in Excel, because ggplot2 wants to guide you toward good visualization practices). So for these plots, I experimented with assigning one variable to each encoding, which meant some continuous variables had to be binned to suit categorical encodings like shape and linetype.
train.adj <- train %>%
  mutate(DIS.group  = cut(log(DIS), 3),
         DIS.group2 = cut(log(DIS), 3, labels = c("first", "second", "third")),
         PTR.group  = cut(PTRATIO, 3),
         PTR.group2 = cut(PTRATIO, 3, labels = c("first", "second", "third")))
ggplot(data = train.adj, aes(x = LSTAT, y = MEDV)) +
geom_point(aes(color = RM, shape = DIS.group2)) +
geom_smooth(aes(group = DIS.group2, linetype = DIS.group2), method = "lm", se = F) +
facet_wrap(vars(CHAS)) +
scale_color_distiller(palette = "Reds") + theme_minimal() +
ggtitle("4 Independent Variables")
## `geom_smooth()` using formula 'y ~ x'
ggplot(data = train.adj, aes(x = LSTAT, y = MEDV)) +
geom_point(aes(color = RM, shape = DIS.group2, size = NOX)) +
geom_smooth(aes(group = DIS.group2, linetype = DIS.group2), method = "lm", se = F) +
facet_wrap(vars(CHAS)) +
scale_color_distiller(palette = "Reds") + theme_minimal() +
ggtitle("5 Independent Variables")
## `geom_smooth()` using formula 'y ~ x'
ggplot(data = train.adj, aes(x = LSTAT, y = MEDV)) +
geom_point(aes(color = RM, shape = DIS.group2, size = NOX, alpha = B)) +
geom_smooth(aes(group = DIS.group2, linetype = DIS.group2), method = "lm", se = F) +
facet_wrap(vars(CHAS)) +
scale_color_distiller(palette = "Reds") + theme_minimal() +
ggtitle("6 Independent Variables")
## `geom_smooth()` using formula 'y ~ x'
ggplot(data = train.adj, aes(x = LSTAT, y = MEDV)) +
geom_point(aes(color = RM, shape = DIS.group2, size = NOX, alpha = B)) +
geom_smooth(aes(group = DIS.group2, linetype = DIS.group2), method = "lm", se = F) +
facet_wrap(vars(CHAS, PTR.group)) +
scale_color_distiller(palette = "Reds") + theme_minimal() +
ggtitle("7 Independent Variables")
## `geom_smooth()` using formula 'y ~ x'
So that’s kind of a garbage plot; there’s too much information. But I think it was a fun thought experiment nonetheless. With data that has slightly more balanced (and sensible) categorical variables, there may be some potential for a reduced version of some of these encodings. At this level of abstraction you lose a lot of the specificity of something like a completely broken-out, faceted-by-variable scatterplot, but you do retain the general gist with a lot of possible ways to compare.
In order to train the surrogate model, I need the random forest’s predictions attached to my data. Because the sample sizes are so small, I’m going to use the full dataset, which includes both the training and the testing data; that’s probably not ideal, but it works well enough for this purpose.
boston.pred <- boston.adj %>% mutate(pred = predict(RFModel, boston.adj))
ggplot(data = boston.pred, aes(x = MEDV, y = pred)) + geom_point() + geom_line(aes(x = MEDV, y = MEDV)) + theme_minimal()
cor(boston.pred$MEDV, boston.pred$pred)
## [1] 0.9760775
To set up the baseline case and to provide a point of comparison for the MARS model, I fit linear regression surrogate models on the predictions. The coverage is pretty good, even with just a basic multiple linear regression and no interaction terms.
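One way to make “coverage” concrete is a fidelity \(R^2\): the usual \(R^2\) formula, but computed against the black-box model’s predictions instead of the true outcome. The function name here is my own:

```r
# Fidelity R^2: fraction of the black-box predictions' variance that the
# surrogate model reproduces
fidelity_r2 <- function(blackbox, surrogate) {
  1 - sum((blackbox - surrogate)^2) / sum((blackbox - mean(blackbox))^2)
}

bb <- c(10, 20, 30, 40)
fidelity_r2(bb, bb)                # → 1 (surrogate mimics the model exactly)
fidelity_r2(bb, rep(mean(bb), 4))  # → 0 (surrogate no better than the mean)
```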
surrModelFull <- lm(pred ~ .-id-MEDV-INDUS-AGE-TAX, data = boston.pred)
summary(surrModelFull)
##
## Call:
## lm(formula = pred ~ . - id - MEDV - INDUS - AGE - TAX, data = boston.pred)
##
## Residuals:
## Min 1Q Median 3Q Max
## -7.4181 -2.3456 -0.5155 1.2552 21.7838
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 29.584835 3.894248 7.597 1.53e-13 ***
## CRIM -0.069040 0.025285 -2.730 0.006551 **
## ZN 0.038334 0.010236 3.745 0.000202 ***
## CHAS 2.878597 0.656958 4.382 1.44e-05 ***
## NOX -16.023065 2.649941 -6.047 2.92e-09 ***
## RM 3.894972 0.311167 12.517 < 2e-16 ***
## DIS -1.127182 0.142653 -7.902 1.80e-14 ***
## RAD 0.066788 0.031128 2.146 0.032392 *
## PTRATIO -0.872922 0.098454 -8.866 < 2e-16 ***
## B 0.007756 0.002062 3.762 0.000189 ***
## LSTAT -0.475590 0.036579 -13.002 < 2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 3.655 on 495 degrees of freedom
## Multiple R-squared: 0.8037, Adjusted R-squared: 0.7998
## F-statistic: 202.7 on 10 and 495 DF, p-value: < 2.2e-16
surrModelFull.simple <- lm(pred ~ LSTAT, data = boston.pred)
summary(surrModelFull.simple)
##
## Call:
## lm(formula = pred ~ LSTAT, data = boston.pred)
##
## Residuals:
## Min 1Q Median 3Q Max
## -10.024 -3.373 -1.430 1.649 18.154
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 33.81857 0.46499 72.73 <2e-16 ***
## LSTAT -0.88987 0.03201 -27.80 <2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 5.137 on 504 degrees of freedom
## Multiple R-squared: 0.6052, Adjusted R-squared: 0.6045
## F-statistic: 772.7 on 1 and 504 DF, p-value: < 2.2e-16
ggplot(data = boston.pred, aes(x = LSTAT, y = pred)) + geom_point() + geom_smooth(method = "lm", se = F) + theme_minimal()
## `geom_smooth()` using formula 'y ~ x'
boston.pred <- boston.pred %>%
mutate(lmPred = predict(surrModelFull, boston.pred),
lmPred.simple = predict(surrModelFull.simple, boston.pred))
ggplot(data = boston.pred, aes(x = MEDV, y = lmPred)) + geom_point() + geom_line(aes(x = MEDV, y = MEDV)) + theme_minimal()
ggplot(data = boston.pred, aes(x = MEDV, y = lmPred.simple)) + geom_point() + geom_line(aes(x = MEDV, y = MEDV)) + theme_minimal()
But the MARS models are better. Because the data is small, and because MARS models are generally fast to run, I can tune the model with a grid search over multiple levels of interaction terms. This gives a 0.93 \(R^2\) with the complex final MARS model. I can also look at the most important parameters in the model, which align pretty closely with what I’d seen previously in the linear models.
hyper_grid <- expand.grid(degree = 1:3,
nprune = seq(2, 50, length.out = 10) %>%
floor())
set.seed(94856)
marsModel <- train(
x = subset(boston.pred, select = -c(id, pred, MEDV, lmPred, lmPred.simple)),
y = boston.pred$pred,
method = "earth",
metric = "RMSE",
trControl = trainControl(method = "cv", number = 10),
tuneGrid = hyper_grid)
marsModel$results
## degree nprune RMSE Rsquared MAE RMSESD RsquaredSD MAESD
## 1 1 2 5.240547 0.6022545 3.893072 0.7610958 0.07410151 0.4649393
## 11 2 2 5.399565 0.5776460 4.010813 0.7087632 0.07778127 0.5054217
## 21 3 2 5.145617 0.6055656 3.814669 0.7541360 0.07245564 0.4524907
## 2 1 7 2.837279 0.8870479 1.968647 0.6876708 0.03418364 0.3792677
## 12 2 7 2.801380 0.8872451 1.933505 0.8433673 0.05685955 0.3501053
## 22 3 7 2.668047 0.8947051 1.918222 0.8655001 0.05561735 0.3810067
## 3 1 12 2.487295 0.9102122 1.713152 0.7922868 0.04308861 0.3380115
## 13 2 12 2.429462 0.9119412 1.683824 1.0598253 0.06238840 0.4293920
## 23 3 12 2.339950 0.9176964 1.611285 1.1078434 0.06517203 0.4621955
## 4 1 18 2.227792 0.9284172 1.572493 0.6319530 0.03121878 0.3310970
## 14 2 18 2.172132 0.9300028 1.506917 0.9135359 0.04595103 0.3959505
## 24 3 18 2.135172 0.9313834 1.469639 0.9264824 0.04696174 0.4060596
## 5 1 23 2.232112 0.9283509 1.574360 0.6222758 0.03073575 0.3249333
## 15 2 23 2.158154 0.9310287 1.444062 0.9686420 0.04708315 0.3936845
## 25 3 23 2.136549 0.9322186 1.421543 0.9905438 0.04807447 0.4088671
## 6 1 28 2.232112 0.9283509 1.574360 0.6222758 0.03073575 0.3249333
## 16 2 28 2.158154 0.9310287 1.444062 0.9686420 0.04708315 0.3936845
## 26 3 28 2.130001 0.9324470 1.415808 0.9950593 0.04825597 0.4122893
## 7 1 34 2.232112 0.9283509 1.574360 0.6222758 0.03073575 0.3249333
## 17 2 34 2.158154 0.9310287 1.444062 0.9686420 0.04708315 0.3936845
## 27 3 34 2.130001 0.9324470 1.415808 0.9950593 0.04825597 0.4122893
## 8 1 39 2.232112 0.9283509 1.574360 0.6222758 0.03073575 0.3249333
## 18 2 39 2.158154 0.9310287 1.444062 0.9686420 0.04708315 0.3936845
## 28 3 39 2.130001 0.9324470 1.415808 0.9950593 0.04825597 0.4122893
## 9 1 44 2.232112 0.9283509 1.574360 0.6222758 0.03073575 0.3249333
## 19 2 44 2.158154 0.9310287 1.444062 0.9686420 0.04708315 0.3936845
## 29 3 44 2.130001 0.9324470 1.415808 0.9950593 0.04825597 0.4122893
## 10 1 50 2.232112 0.9283509 1.574360 0.6222758 0.03073575 0.3249333
## 20 2 50 2.158154 0.9310287 1.444062 0.9686420 0.04708315 0.3936845
## 30 3 50 2.130001 0.9324470 1.415808 0.9950593 0.04825597 0.4122893
marsModel$results %>%
filter(nprune==marsModel$bestTune$nprune, degree==marsModel$bestTune$degree)
## degree nprune RMSE Rsquared MAE RMSESD RsquaredSD MAESD
## 1 3 28 2.130001 0.932447 1.415808 0.9950593 0.04825597 0.4122893
ggplot(marsModel) + theme_minimal()
summary(marsModel)
## Call: earth(x=tbl_df[506,13], y=c(25.72,22.19,3...), keepxy=TRUE, degree=3,
## nprune=28)
##
## coefficients
## (Intercept) 23.17273
## h(18.811-CRIM) 0.16305
## h(CRIM-18.811) -0.04299
## h(6.395-RM) -1.89017
## h(RM-6.395) 11.01381
## h(RM-7.929) -8.47393
## h(1.3567-DIS) 645.47789
## h(DIS-1.3567) -0.41676
## h(14.7-PTRATIO) 0.99615
## h(LSTAT-6.36) -0.58456
## h(1.3567-DIS) * B -1.60401
## h(0.713-NOX) * h(LSTAT-6.36) 1.35024
## h(RM-6.395) * h(RAD-4) -0.32440
## h(RM-6.395) * h(4-RAD) -0.76108
## h(6.395-RM) * h(LSTAT-18.76) 0.33951
## h(6.395-RM) * h(18.76-LSTAT) -0.30618
## h(AGE-20.1) * h(PTRATIO-14.7) -0.00731
## h(305-TAX) * h(6.36-LSTAT) 0.02317
## h(TAX-305) * h(6.36-LSTAT) 0.01794
## h(6.395-RM) * h(TAX-224) * h(18.76-LSTAT) 0.00127
## h(6.395-RM) * h(224-TAX) * h(18.76-LSTAT) 0.04484
##
## Selected 21 of 25 terms, and 10 of 13 predictors (nprune=28)
## Termination condition: Reached nk 27
## Importance: RM, LSTAT, RAD, TAX, DIS, AGE, PTRATIO, NOX, B, CRIM, ...
## Number of terms at each degree of interaction: 1 9 9 2
## GCV 2.939801 RSS 1202.791 GRSq 0.9560247 RSq 0.9643016
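For reference, the GCV value in this readout is generalized cross-validation, the criterion MARS uses for pruning: roughly \(\mathrm{GCV} = \frac{\mathrm{RSS}/n}{(1 - M/n)^2}\), where \(M\) is the effective number of parameters (the number of terms plus a penalty charged per knot). It approximates out-of-sample error without holding data out, which is why GRSq sits a little below RSq in the summary above.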
p1 <- vip(marsModel, num_features = 40, geom = "point", value = "gcv") + ggtitle("GCV") + theme_minimal()
p2 <- vip(marsModel, num_features = 40, geom = "point", value = "rss") + ggtitle("RSS") + theme_minimal()
gridExtra::grid.arrange(p1, p2, ncol = 2)
boston.pred <- boston.pred %>%
mutate(marsPred = predict(marsModel, boston.pred))
ggplot(data = boston.pred, aes(x = MEDV, y = marsPred)) + geom_point() + geom_line(aes(x = MEDV, y = MEDV)) + theme_minimal()
ggplot(data = boston.pred, aes(x = LSTAT, y = marsPred)) + geom_point() + theme_minimal()
# imp = varImp(marsModel)
# imp = imp$importance %>%
# as.data.frame() %>%
# mutate( variable = row.names(.) ) %>%
# filter( Overall > 0 )
# plotmo(marsModel, all1 = T)
I also made a simple MARS model using only the same variable as the simple linear regression above, as a point of comparison; it is, of course, also substantially better.
marsModel.simple <- train(
x = subset(boston.pred, select = LSTAT),
y = boston.pred$pred,
method = "earth",
metric = "RMSE",
trControl = trainControl(method = "cv", number = 10),
tuneGrid = hyper_grid)
marsModel.simple$results %>%
filter(nprune==marsModel.simple$bestTune$nprune, degree==marsModel.simple$bestTune$degree)
## degree nprune RMSE Rsquared MAE RMSESD RsquaredSD MAESD
## 1 1 7 3.995258 0.7514938 2.888066 0.9022384 0.117816 0.4300581
marsModel.simple$finalModel
## Selected 4 of 4 terms, and 1 of 1 predictors (nprune=7)
## Termination condition: RSq changed by less than 0.001 at 4 terms
## Importance: LSTAT
## Number of terms at each degree of interaction: 1 3 (additive model)
## GCV 16.34901 RSS 8045.295 GRSq 0.7554416 RSq 0.7612184
boston.pred <- boston.pred %>%
mutate(marsPred.simple = predict(marsModel.simple, boston.pred))
ggplot(data = boston.pred, aes(x = MEDV, y = marsPred.simple)) + geom_point() + geom_line(aes(x = MEDV, y = MEDV)) + theme_minimal()
comp1 <- ggplot(data = boston.pred, aes(x = LSTAT, y = marsPred.simple)) + geom_point() + theme_minimal()
comp2 <- ggplot(data = boston.pred, aes(x = LSTAT, y = MEDV)) + geom_point() + theme_minimal()
gridExtra::grid.arrange(comp1, comp2, ncol = 2)
For reasons that I’ll explain later, I’m also going to make a normalized version of all of this data: every variable will be rescaled to [0, 1] from whatever its original range was. This will be really useful for comparing between variables later. Additionally, I can get a broken-out plot of our MARS splines, which helps with intuition and with checking the splines.
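The “range” method in caret’s preProcess is just min-max scaling; in base R it amounts to the following (with the caveat that, for new data, the training minima and maxima should be reused, which is what predict() on a preProcess object does):

```r
# Min-max rescaling to [0, 1], the same idea as preProcess(method = "range")
rescale01 <- function(x) (x - min(x)) / (max(x) - min(x))

rescale01(c(2, 4, 6))  # → 0.0 0.5 1.0
```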
train.pred <- train %>% mutate(pred = predict(RFModel, train))
process.pred <- preProcess(as.data.frame(train.pred), method=c("range"))
train.pred.norm <- bind_cols(predict(process.pred, as.data.frame(train.pred)) %>% select(-c(MEDV, id, pred)),
train.pred[14:16])
marsModel.norm <- train(
x = subset(train.pred.norm, select = -c(id,INDUS,AGE,TAX, MEDV, pred)),
y = train.pred.norm$pred,
method = "earth",
metric = "RMSE",
trControl = trainControl(method = "cv", number = 10),
tuneGrid = hyper_grid)
marsModel.norm$results %>%
filter(nprune==marsModel.norm$bestTune$nprune, degree==marsModel.norm$bestTune$degree)
## degree nprune RMSE Rsquared MAE RMSESD RsquaredSD MAESD
## 1 1 18 2.477986 0.9072029 1.855953 0.5268922 0.05946177 0.3307444
marsModel.norm$finalModel
## Selected 17 of 18 terms, and 8 of 10 predictors (nprune=18)
## Termination condition: Reached nk 21
## Importance: RM, LSTAT, CRIM, PTRATIO, DIS, NOX, CHAS, B, ZN-unused, ...
## Number of terms at each degree of interaction: 1 16 (additive model)
## GCV 5.91038 RSS 1720.371 GRSq 0.9125132 RSq 0.9276559
# ggplot(marsModel.norm)
p1.marsnorm <- vip(marsModel.norm, num_features = 40, geom = "point", value = "gcv") + ggtitle("GCV") + theme_minimal()
p2.marsnorm <- vip(marsModel.norm, num_features = 40, geom = "point", value = "rss") + ggtitle("RSS") + theme_minimal()
gridExtra::grid.arrange(p1.marsnorm, p2.marsnorm, ncol = 2)
terms <- marsModel.norm$finalModel$coefficients %>%
as.data.frame() %>%
mutate( parameter = row.names(.)) %>%
select( parameter, coef = y ) %>%
mutate(spline = as.numeric(str_extract(parameter, "\\-?\\d*\\.\\d*")),
var = str_extract(parameter, "[A-Z]+"),
var = if_else(var == "I", "Intercept", var),
side = if_else(str_detect(parameter, "h\\(\\-?[0-9]"), "left",
if_else(str_detect(parameter, "h\\([A-Z]"),"right","")),
weight = if_else(side == "left", spline,
if_else(side == "right", 1-spline, 0)),
weight = if_else(var == "DIS" & weight == 0.1347920, 0.0838145, weight),
weight = if_else(var == "CRIM" & weight == 0.9450162, 0.1653472, weight),
splineWeight = coef*weight)
rownames(terms) <- NULL
imp = varImp(marsModel.norm)
imp = imp$importance %>%
as.data.frame() %>%
mutate( variable = row.names(.) ) %>%
filter( Overall > 0 )
plotmo(marsModel.norm, all1 = T)
## plotmo grid: CRIM ZN CHAS NOX RM DIS RAD PTRATIO
## 0.002661631 0 0 0.3148148 0.5109858 0.2313401 0.173913 0.6595745
## B LSTAT
## 0.9861188 0.2523455
The main package for creating nomograms is rms, which, as far as I could find, only builds nomograms from models fit with its own functions, which is very frustrating. Its ols() function is equivalent to lm() in base R, so I’m going to use that for the proof of concept here. First, I’m showing a nomogram of a single variable, which is very similar to the original nomograms, where a value is read off by drawing a line between two scales. Like some of the other plots for evaluating ML techniques, nomograms also show the direction of each variable’s relationship (positive or negative) pretty plainly.
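Conceptually, the point scale in a nomogram for a linear model rescales each predictor’s largest possible contribution, \(|\beta_j| \times \mathrm{range}(x_j)\), so that the biggest contribution spans 0 to 100 points. A rough sketch of that arithmetic (the slopes and ranges are illustrative round-offs of values seen above, and this is not exactly how rms lays out its scales):

```r
# Sketch of nomogram point scales: the predictor with the largest possible
# contribution gets the full 0-100 point axis; others are scaled against it
coefs  <- c(LSTAT = -0.94, RM = 5.23)  # illustrative slopes
ranges <- c(LSTAT = 36.24, RM = 5.22)  # illustrative max - min per predictor
contrib   <- abs(coefs) * ranges       # largest possible effect on the outcome
points100 <- 100 * contrib / max(contrib)
round(points100, 1)  # LSTAT gets 100 points; RM gets proportionally fewer
```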
OLSsegment <- ols(MEDV~LSTAT, data = train)
OLSsegment$coefficients
## Intercept LSTAT
## 34.3655950 -0.9378105
ddist <- datadist(train)
options(datadist = "ddist")
plot(nomogram(OLSsegment))
OLSsegment
## Linear Regression Model
##
## ols(formula = MEDV ~ LSTAT, data = train)
##
## Model Likelihood Discrimination
## Ratio Test Indexes
## Obs 354 LR chi2 262.62 R2 0.524
## sigma6.2395 d.f. 1 R2 adj 0.522
## d.f. 352 Pr(> chi2) 0.0000 g 7.162
##
## Residuals
##
## Min 1Q Median 3Q Max
## -15.076 -3.872 -1.270 1.596 24.572
##
##
## Coef S.E. t Pr(>|t|)
## Intercept 34.3656 0.6771 50.75 <0.0001
## LSTAT -0.9378 0.0477 -19.68 <0.0001
##
If you add another variable, you get something similar, but the total points scale begins to shift.
OLSsegment2 <- ols(MEDV~LSTAT + RM, data = train)
plot(nomogram(OLSsegment2))
checktrain <- train %>% select(-c(id,INDUS,AGE,TAX))
OLSfull <- ols(formula = MEDV~., data = checktrain)
OLSfull$coefficients
## Intercept CRIM ZN CHAS NOX RM
## 29.47032477 -0.14754683 0.04628502 3.19006598 -16.41365708 4.06301145
## DIS RAD PTRATIO B LSTAT
## -1.47753760 0.17618272 -0.98990886 0.01527127 -0.53932865
ddist <- datadist(checktrain)
options(datadist = "ddist")
plot(nomogram(OLSfull))
I also wanted to look at splines in nomograms, which rms supports. However, it doesn’t support MARS models directly, so the actual method of putting together this nomogram will have to be somewhat hacked together.
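The reason this hack can work at all is an algebraic identity: for a single knot, the mirrored hinge pair MARS fits spans exactly the same functions as the x-plus-\((x-k)_+\) parameterization that rms’s lsp() uses. A quick base-R check with toy coefficients:

```r
# Why lsp() can stand in for additive MARS terms: a mirrored hinge pair
#   b1*h(x - k) + b2*h(k - x)
# equals an intercept shift, a slope on x, and a slope change at the knot,
#   b2*k - b2*x + (b1 + b2)*h(x - k),
# which is the x + (x - k)+ form that lsp() fits
hinge <- function(x, c) pmax(0, x - c)

x <- seq(0, 1, by = 0.1)
k <- 0.4; b1 <- 2; b2 <- -3
mars_form <- b1 * hinge(x, k) + b2 * hinge(k, x)
lsp_form  <- b2 * k - b2 * x + (b1 + b2) * hinge(x, k)
all.equal(mars_form, lsp_form)  # → TRUE
```

(This only covers additive terms; interaction terms in a MARS model have no lsp() equivalent, which is why the nomogram version has to stick to an additive approximation.)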
splineTest = ols(MEDV~lsp(LSTAT, 6) + RM, data = checktrain)
splineTest
## Linear Regression Model
##
## ols(formula = MEDV ~ lsp(LSTAT, 6) + RM, data = checktrain)
##
## Model Likelihood Discrimination
## Ratio Test Indexes
## Obs 354 LR chi2 423.83 R2 0.698
## sigma4.9831 d.f. 3 R2 adj 0.695
## d.f. 350 Pr(> chi2) 0.0000 g 8.018
##
## Residuals
##
## Min 1Q Median 3Q Max
## -11.6482 -2.7341 -0.4578 2.1593 28.0631
##
##
## Coef S.E. t Pr(>|t|)
## Intercept 29.4187 4.8439 6.07 <0.0001
## LSTAT -4.5479 0.4193 -10.85 <0.0001
## LSTAT' 4.0242 0.4286 9.39 <0.0001
## RM 3.6279 0.5194 6.98 <0.0001
##
plot(nomogram(splineTest))
We can also make a normalized nomogram, using the normalized data from above. Excitingly, its point scales now line up with the variable importance we saw earlier.
OLSfull.norm <- ols(pred~., data = train.pred.norm %>% select(-c(id,INDUS,AGE,TAX,MEDV)))
OLSfull.norm$coefficients
## Intercept CRIM ZN CHAS NOX RM DIS
## 22.759437 -9.306302 4.214486 2.649818 -7.441217 20.288769 -12.052393
## RAD PTRATIO B LSTAT
## 2.849557 -8.722674 5.047615 -18.061111
ddist <- datadist(train.pred.norm %>% select(-c(id,INDUS,AGE,TAX,MEDV)))
options(datadist = "ddist")
plot(nomogram(OLSfull.norm))
OLSfull.norm
## Linear Regression Model
##
## ols(formula = pred ~ ., data = train.pred.norm %>% select(-c(id,
## INDUS, AGE, TAX, MEDV)))
##
## Model Likelihood Discrimination
## Ratio Test Indexes
## Obs 354 LR chi2 553.58 R2 0.791
## sigma3.8097 d.f. 10 R2 adj 0.785
## d.f. 343 Pr(> chi2) 0.0000 g 8.177
##
## Residuals
##
## Min 1Q Median 3Q Max
## -8.1875 -2.2951 -0.5442 1.3309 20.5132
##
##
## Coef S.E. t Pr(>|t|)
## Intercept 22.7594 2.3284 9.77 <0.0001
## CRIM -9.3063 2.8499 -3.27 0.0012
## ZN 4.2145 1.2677 3.32 0.0010
## CHAS 2.6498 0.8221 3.22 0.0014
## NOX -7.4412 1.6592 -4.48 <0.0001
## RM 20.2888 2.0787 9.76 <0.0001
## DIS -12.0524 1.7188 -7.01 <0.0001
## RAD 2.8496 0.8995 3.17 0.0017
## PTRATIO -8.7227 1.1451 -7.62 <0.0001
## B 5.0476 1.0413 4.85 <0.0001
## LSTAT -18.0611 1.6844 -10.72 <0.0001
##
I also created two visualizations that attempt to get at the interaction between the spline points and their coefficients, given the multiplicative and partial nature of these piecewise functions.
ggplot(terms, aes(x = var, y = splineWeight)) +
geom_segment(aes(x=var, xend=var, y=0, yend=splineWeight)) +
geom_point(color = "coral", size = 3) + coord_flip() + theme_minimal()
ggplot(terms, aes(x = var, y = coef)) +
geom_segment(aes(x=var, xend=var, y=0, yend=coef)) +
geom_point(color = "coral", size = 3) + geom_hline(aes(yintercept = 0)) + coord_flip() + theme_minimal()
So, from all of this, I’m going to approximate the MARS model using the linear-spline (lsp) functionality of the rms ols() function.
MARSspline = ols(pred~lsp(LSTAT,0.1324500) +
lsp(RM,0.5625120) +
lsp(PTRATIO,0.2234040) +
lsp(B,0.9966210) +
lsp(CRIM,0.0549838) +
# lsp(CRIM,0.2203310) +  # second CRIM knot from MARS, dropped here
lsp(NOX,0.5061730) +
# lsp(DIS,0.0509775) +   # second DIS knot from MARS, dropped here
lsp(DIS,0.1347920) + CHAS, data = train.pred.norm)
MARSspline
## Linear Regression Model
##
## ols(formula = pred ~ lsp(LSTAT, 0.13245) + lsp(RM, 0.562512) +
## lsp(PTRATIO, 0.223404) + lsp(B, 0.996621) + lsp(CRIM, 0.0549838) +
## lsp(NOX, 0.506173) + lsp(DIS, 0.134792) + CHAS, data = train.pred.norm)
##
## Model Likelihood Discrimination
## Ratio Test Indexes
## Obs 354 LR chi2 839.31 R2 0.907
## sigma 2.5634 d.f. 15 R2 adj 0.902
## d.f. 338 Pr(> chi2) 0.0000 g 8.356
##
## Residuals
##
## Min 1Q Median 3Q Max
## -9.4765 -1.3674 -0.1197 1.3058 14.6002
##
##
## Coef S.E. t Pr(>|t|)
## Intercept 45.3694 2.4827 18.27 <0.0001
## LSTAT -73.9885 7.7099 -9.60 <0.0001
## LSTAT' 55.9523 8.0428 6.96 <0.0001
## RM -2.7659 2.2431 -1.23 0.2184
## RM' 44.0836 3.9190 11.25 <0.0001
## PTRATIO -13.8313 5.5438 -2.49 0.0131
## PTRATIO' 8.7777 5.8378 1.50 0.1336
## B 3.4854 0.7529 4.63 <0.0001
## B' -270.2253 99.1829 -2.72 0.0068
## CRIM 31.6841 11.6964 2.71 0.0071
## CRIM' -48.1752 12.3773 -3.89 0.0001
## NOX -7.7015 2.1678 -3.55 0.0004
## NOX' -5.1858 3.0702 -1.69 0.0921
## DIS -44.3073 6.6298 -6.68 <0.0001
## DIS' 39.2226 6.6869 5.87 <0.0001
## CHAS 2.2448 0.5623 3.99 <0.0001
##
ddist <- datadist(train.pred.norm %>% select(-c(id,INDUS,AGE,TAX,MEDV)))
options(datadist = "ddist")
plot(nomogram(MARSspline))
The image again, with a better aspect ratio:
[Figure: MARS nomogram]
And finally, a slightly more intuitive sense of the MARS equations, here from the readout:
library(knitr)
kable(terms)
| parameter | coef | spline | var | side | weight | splineWeight |
|---|---|---|---|---|---|---|
| (Intercept) | 9.145261 | NA | Intercept | | 0.0000000 | 0.0000000 |
| h(LSTAT-0.13245) | -16.278165 | 0.1324500 | LSTAT | right | 0.8675500 | -14.1221222 |
| h(0.13245-LSTAT) | 72.950209 | 0.1324500 | LSTAT | left | 0.1324500 | 9.6622552 |
| h(RM-0.562512) | 41.234692 | 0.5625120 | RM | right | 0.4374880 | 18.0396828 |
| h(PTRATIO-0.223404) | -4.902470 | 0.2234040 | PTRATIO | right | 0.7765960 | -3.8072389 |
| h(0.223404-PTRATIO) | 16.719846 | 0.2234040 | PTRATIO | left | 0.2234040 | 3.7352804 |
| h(B-0.996621) | -279.822919 | 0.9966210 | B | right | 0.0033790 | -0.9455216 |
| h(0.996621-B) | -2.234310 | 0.9966210 | B | left | 0.9966210 | -2.2267605 |
| h(CRIM-0.0549838) | -39.839538 | 0.0549838 | CRIM | right | 0.1653472 | -6.5873561 |
| h(0.0549838-CRIM) | -37.035513 | 0.0549838 | CRIM | left | 0.0549838 | -2.0363532 |
| h(NOX-0.506173) | -11.440907 | 0.5061730 | NOX | right | 0.4938270 | -5.6498286 |
| h(0.506173-NOX) | 6.824287 | 0.5061730 | NOX | left | 0.5061730 | 3.4542699 |
| h(DIS-0.134792) | -217.594587 | 0.1347920 | DIS | right | 0.8652080 | -188.2645778 |
| h(0.134792-DIS) | 222.690190 | 0.1347920 | DIS | left | 0.0838145 | 18.6646669 |
| h(DIS-0.0509775) | 212.251353 | 0.0509775 | DIS | right | 0.9490225 | 201.4313097 |
| h(CRIM-0.220331) | 31.685813 | 0.2203310 | CRIM | right | 0.7796690 | 24.7044464 |
| CHAS | 1.674235 | NA | CHAS | | 0.0000000 | 0.0000000 |
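As a sanity check on the table, each splineWeight is simply coef × weight. For example, for the right-hand LSTAT hinge row:

```r
# Sanity check: splineWeight = coef * weight (right-hand LSTAT hinge row)
b <- -16.278165      # coefficient on h(LSTAT - 0.13245)
w <- 1 - 0.1324500   # active width of the right-hand hinge on [0, 1]
stopifnot(all.equal(b * w, -14.1221222, tolerance = 1e-6))
```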
And here, in equation form. Ideally, these would be incorporated with the scaled nomogram into an overall visualization.
\[\text{prediction} = 9.15 - 16.28\max(0, \text{LSTAT}-0.13)\\\\ + 72.95\max(0, 0.13-\text{LSTAT})\\\\ + 41.23\max(0, \text{RM}-0.56)\\\\ - 4.90\max(0, \text{PTRATIO}-0.22)\\\\ + 16.72\max(0, 0.22-\text{PTRATIO})\\\\ - 279.82\max(0, \text{B}-0.997)\\\\ - 2.23\max(0, 0.997-\text{B})\\\\ - 39.84\max(0, \text{CRIM}-0.055)\\\\ - 37.04\max(0, 0.055-\text{CRIM})\\\\ - 11.44\max(0, \text{NOX}-0.51)\\\\ + 6.82\max(0, 0.51-\text{NOX})\\\\ - 217.59\max(0, \text{DIS}-0.13)\\\\ + 222.69\max(0, 0.13-\text{DIS})\\\\ + 212.25\max(0, \text{DIS}-0.05)\\\\ + 31.69\max(0, \text{CRIM}-0.22)\\\\ + 1.67\,\text{CHAS}\]
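To show the equation really is usable as a calculator, here is a base-R evaluation of it at one hypothetical min-max-normalized observation (the input values are illustrative, not drawn from the data):

```r
# Hinge function used throughout the MARS surrogate equation
h <- function(x) pmax(0, x)

# Hypothetical normalized observation (illustrative values only)
obs <- list(LSTAT = 0.2, RM = 0.6, PTRATIO = 0.5, B = 0.9,
            CRIM = 0.01, NOX = 0.4, DIS = 0.3, CHAS = 0)

# The surrogate equation, transcribed term by term
pred <- with(obs,
  9.15 -
   16.28 * h(LSTAT - 0.13)   +  72.95 * h(0.13 - LSTAT) +
   41.23 * h(RM - 0.56)      -
    4.90 * h(PTRATIO - 0.22) +  16.72 * h(0.22 - PTRATIO) -
  279.82 * h(B - 0.997)      -   2.23 * h(0.997 - B) -
   39.84 * h(CRIM - 0.055)   -  37.04 * h(0.055 - CRIM) -
   11.44 * h(NOX - 0.51)     +   6.82 * h(0.51 - NOX) -
  217.59 * h(DIS - 0.13)     + 222.69 * h(0.13 - DIS) +
  212.25 * h(DIS - 0.05)     +  31.69 * h(CRIM - 0.22) +
    1.67 * CHAS)
round(pred, 2)
```

Only a handful of hinges are active at any given input, which is what makes tracing a prediction through the nomogram tractable by hand.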
Many of my results can be found in the plots in the sections above, but loosely summarized: there is still a long way to go before this methodology is fully practical, largely in terms of building up the visualization space. Many of the existing tools are limited and somewhat counterintuitive to use. More of this is discussed in the next section.
Beyond this, I think I was able to explore more of the visualization space, not only for visualizing MARS models but for regression visualization more generally. The main constraints were limitations in the existing tools and the scale of the project, in terms of what I could build in a reasonable time frame.
I think the scaling on the nomograms worked very effectively for MARS models, even if the visualization space was somewhat limited. This kind of solution would be difficult to apply without explanation and acclimation for its users, but a better, interactive interface could really improve its usability.
The current versions of dynamic nomograms are… not very reminiscent of nomograms as a visualization tool. They show the results, and by querying combinations of categories you can get to something slightly closer to an actual nomogram, but not quite.
I would be interested in creating something that’s closer to the functionality of an actual nomogram. One of the appeals of a nomogram as an interpretability tool is the way you can use the nomogram to literally complete the calculation, to a certain degree of accuracy. It is a visualization tool intended for calculation and intuitive interpretation.
Additionally, the nomogram function currently doesn't work with more complicated kinds of splines, and specifically not with the MARS technique I explored for this project. Here, the spline knots were located with MARS through a separate modeling process and then transferred to the nomogram function. Because I relied on a pre-built function for this project, I wasn't able to engineer a new implementation that would work with MARS or with other kinds of models beyond those implemented in the rms package. While rms is an extensive package with a substantial amount of modeling functionality, its nomogram function only works with models fit by the package, meaning that other kinds of models, such as MARS, can't be visualized with it directly.
I would want to use the nomogram as a base for visualizing MARS models, the beginnings of which can be seen in my previous visualizations, and to incorporate the interactive functionality mentioned above.